/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { GenkitError, MessageData, z, type Genkit } from 'genkit';
import {
  getBasicUsageStats,
  modelRef,
  type GenerateRequest,
  type ModelAction,
  type ModelInfo,
  type ModelReference,
} from 'genkit/model';
import { getApiKeyFromEnvVar } from './common.js';
import { predictModel } from './predict.js';

/**
 * String-literal type listing the Imagen model identifiers named in this
 * file. At present it has a single member, `imagen-3.0-generate-002`.
 */
export type KNOWN_IMAGEN_MODELS = 'imagen-3.0-generate-002';

/**
 * Zod schema for the per-request configuration of the Imagen models defined
 * in this file (it is the `configSchema` handed to `defineModel`).
 *
 * Every declared field is optional. The schema ends in `.passthrough()`, so
 * keys that are not declared here are kept on the parsed config instead of
 * being stripped.
 *
 * See https://ai.google.dev/gemini-api/docs/image-generation#imagen-model
 */
export const ImagenConfigSchema = z
  .object({
    numberOfImages: z
      .number()
      .describe(
        'The number of images to generate, from 1 to 4 (inclusive). The default is 1.'
      )
      .optional(),
    aspectRatio: z
      .enum(['1:1', '9:16', '16:9', '3:4', '4:3'])
      .describe('Desired aspect ratio of the output image.')
      .optional(),
    personGeneration: z
      .enum(['dont_allow', 'allow_adult', 'allow_all'])
      .describe(
        'Control if/how images of people will be generated by the model.'
      )
      .optional(),
  })
  .passthrough();

/**
 * Shape of the `parameters` value passed to the predict client next to the
 * instances. It is produced by `toParameters` from the request config.
 */
interface ImagenParameters {
  // Filled from the `numberOfImages` config option (defaulting to 1).
  sampleCount?: number;
  // Carried over from the config option of the same name.
  aspectRatio?: string;
  // Carried over from the config option of the same name.
  personGeneration?: string;
}

/**
 * Converts the request config into the `parameters` object sent to the
 * predict endpoint.
 *
 * - `numberOfImages` is translated to `sampleCount` (defaulting to 1) and is
 *   not forwarded under its original name.
 * - Every other config key, including passthrough keys not declared in the
 *   schema, is forwarded unchanged; a `sampleCount` key supplied directly in
 *   the config takes precedence over the translated value.
 * - Keys whose value is `undefined` or `null` are removed. Explicit falsy
 *   values such as `false` or `0` are kept so callers can switch an option
 *   off.
 *
 * @param request The generate request whose `config` is converted.
 * @returns The parameters object for the predict call.
 */
function toParameters(
  request: GenerateRequest<typeof ImagenConfigSchema>
): ImagenParameters {
  const params: ImagenParameters & Record<string, unknown> = {
    sampleCount: request.config?.numberOfImages ?? 1,
    ...request.config,
  };

  // Already expressed as `sampleCount`; do not send the same setting twice
  // under a name that is not part of the parameters shape.
  delete params.numberOfImages;

  // Prune only unset values. A truthiness test here would also discard
  // deliberate `false` / `0` settings.
  for (const key of Object.keys(params)) {
    if (params[key] === undefined || params[key] === null) {
      delete params[key];
    }
  }

  return params;
}

/**
 * Builds the text prompt from the last message of the request.
 *
 * The `text` of every part in that message is concatenated in order; parts
 * without text (for example media parts) contribute nothing.
 *
 * @param request The generate request to read the prompt from.
 * @returns The concatenated text, or an empty string when the request
 *   contains no messages.
 */
function extractText(request: GenerateRequest): string {
  const lastMessage = request.messages.at(-1);
  // Tolerate an empty message list the same way extractBaseImage does,
  // instead of failing with a TypeError on `undefined.content`.
  if (!lastMessage) {
    return '';
  }
  return lastMessage.content.map((part) => part.text || '').join('');
}

/**
 * Returns the payload of the first media part in the last message, when that
 * media is supplied as a `data:` URL.
 *
 * The payload is everything after the first comma of the data URL. Media
 * given by any other kind of URL is ignored, because such a URL carries no
 * inline bytes; splitting it on a comma would return a fragment of the URL
 * itself rather than image data.
 *
 * @param request The generate request to inspect.
 * @returns The data-URL payload, or `undefined` when there is no message, no
 *   media part, or the media URL is not a `data:` URL.
 */
function extractBaseImage(request: GenerateRequest): string | undefined {
  const url = request.messages.at(-1)?.content.find((part) => !!part.media)
    ?.media?.url;
  // The URL scheme is case-insensitive, hence the `i` flag.
  if (!url || !/^data:/i.test(url)) {
    return undefined;
  }
  const separator = url.indexOf(',');
  return separator === -1 ? undefined : url.slice(separator + 1);
}

/**
 * Shape of the predict response as this file consumes it: a list of
 * predictions, each with `bytesBase64Encoded` and `mimeType`, which are
 * combined into a `data:` URL for the returned media part.
 */
interface ImagenPrediction {
  predictions: { bytesBase64Encoded: string; mimeType: string }[];
}

/**
 * One entry of the instances array passed to the predict client.
 */
interface ImagenInstance {
  // Text prompt, built from the text parts of the last request message.
  prompt: string;
  // Optional input image; set only when the request carries a media part.
  image?: { bytesBase64Encoded: string };
  // Declared here but never populated in this file.
  mask?: { image?: { bytesBase64Encoded: string } };
}

/**
 * Baseline model info shared by Imagen models: media input is supported,
 * multi-turn, tools and the system role are not, and the only output kind is
 * `media`. `defineImagenModel` spreads this and overrides `label` per model.
 */
export const GENERIC_IMAGEN_INFO = {
  label: 'Google AI - Generic Imagen',
  supports: {
    media: true,
    multiturn: false,
    tools: false,
    systemRole: false,
    output: ['media'],
  },
} as ModelInfo;

/**
 * Defines an Imagen model on the given Genkit instance under the name
 * `googleai/<name>`.
 *
 * The model handler sends the last message's text (and, when present, its
 * first media part) to the predict client and returns every prediction as a
 * media part carrying a `data:` URL.
 *
 * @param ai The Genkit instance to define the model on.
 * @param name The model name, without the `googleai/` prefix.
 * @param apiKey API key to use. When omitted or empty, the key is read via
 *   `getApiKeyFromEnvVar()`. Passing `false` skips key lookup and validation
 *   entirely.
 * @returns The defined model action.
 * @throws GenkitError with status `FAILED_PRECONDITION` when `apiKey` is not
 *   `false` and no key can be resolved.
 */
export function defineImagenModel(
  ai: Genkit,
  name: string,
  apiKey?: string | false
): ModelAction {
  let resolvedKey: string | false | undefined = apiKey;
  if (resolvedKey !== false) {
    resolvedKey = resolvedKey || getApiKeyFromEnvVar();
    if (!resolvedKey) {
      throw new GenkitError({
        status: 'FAILED_PRECONDITION',
        message:
          'Please pass in the API key or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.\n' +
          'For more details see https://genkit.dev/docs/plugins/google-genai',
      });
    }
  }

  const qualifiedName = `googleai/${name}`;
  const ref: ModelReference<z.ZodTypeAny> = modelRef({
    name: qualifiedName,
    info: { ...GENERIC_IMAGEN_INFO, label: `Google AI - ${name}` },
    configSchema: ImagenConfigSchema,
  });

  return ai.defineModel(
    { name: qualifiedName, ...ref.info, configSchema: ImagenConfigSchema },
    async (request) => {
      const prompt = extractText(request);
      const inlineImage = extractBaseImage(request);
      // Attach the image only when the request actually carried one.
      const instance: ImagenInstance = inlineImage
        ? { prompt, image: { bytesBase64Encoded: inlineImage } }
        : { prompt };

      const predict = predictModel<
        ImagenInstance,
        ImagenPrediction,
        ImagenParameters
      >(ref.version || name, resolvedKey as string, 'predict');
      const response = await predict([instance], toParameters(request));

      const predictions = response.predictions;
      if (!predictions || predictions.length === 0) {
        throw new Error(
          'Model returned no predictions. Possibly due to content filters.'
        );
      }

      // One media part per prediction, encoded as a data URL.
      const message: MessageData = {
        role: 'model',
        content: predictions.map(({ bytesBase64Encoded, mimeType }) => ({
          media: {
            url: `data:${mimeType};base64,${bytesBase64Encoded}`,
            contentType: mimeType,
          },
        })),
      };

      return {
        finishReason: 'stop',
        message,
        usage: getBasicUsageStats(request.messages, message),
        custom: response,
      };
    }
  );
}
